"""
2024.3.12: 去掉deformable_attn后可以work的版本
"""
# Copyright (c) OpenMMLab. All rights reserved.
from typing import List, Tuple, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import Conv2d
from mmcv.cnn.bricks.transformer import MultiScaleDeformableAttention
from mmdet.models.layers.transformer.mmcv_spike.multi_scale_deform_attn import SpikeMultiScaleDeformableAttention
from mmengine.model import (BaseModule, ModuleList, caffe2_xavier_init,
                            normal_init, xavier_init)
from torch import Tensor

from mmdet.registry import MODELS
from mmdet.utils import ConfigType, OptMultiConfig
from ..task_modules.prior_generators import MlvlPointGenerator
from .positional_encoding import SinePositionalEncoding
from .transformer import Mask2FormerTransformerEncoder
from mmdet.models.utils.Qtrick import MultiSpike_norm4


@MODELS.register_module()
class MSDeformAttnPixelDecoder(BaseModule):
    """Pixel decoder with multi-scale deformable attention.

    Args:
        in_channels (list[int] | tuple[int]): Number of channels in the
            input feature maps.
        strides (list[int] | tuple[int]): Output strides of feature from
            backbone.
        feat_channels (int): Number of channels for feature.
        out_channels (int): Number of channels for output.
        num_outs (int): Number of output scales.
        T (int): Number of time steps used by the spiking activations
            (``MultiSpike_norm4``). Defaults to 4.
        norm_cfg (:obj:`ConfigDict` or dict): Config for normalization.
            Defaults to dict(type='GN', num_groups=32).
        act_cfg (:obj:`ConfigDict` or dict): Config for activation.
            Defaults to dict(type='ReLU').
        encoder (:obj:`ConfigDict` or dict): Config for transformer
            encoder. Defaults to None.
        positional_encoding (:obj:`ConfigDict` or dict): Config for
            transformer encoder position encoding. Defaults to
            dict(num_feats=128, normalize=True).
        init_cfg (:obj:`ConfigDict` or dict or list[:obj:`ConfigDict` or \
            dict], optional): Initialization config dict. Defaults to None.
    """

    def __init__(self,
                 in_channels: Union[List[int],
                                    Tuple[int]] = [256, 512, 1024, 2048],
                 strides: Union[List[int], Tuple[int]] = [4, 8, 16, 32],
                 feat_channels: int = 256,
                 out_channels: int = 256,
                 num_outs: int = 3,
                 T: int = 4,
                 norm_cfg: ConfigType = dict(type='GN', num_groups=32),
                 act_cfg: ConfigType = dict(type='ReLU'),
                 encoder: ConfigType = None,
                 positional_encoding: ConfigType = dict(
                     num_feats=128, normalize=True),
                 init_cfg: OptMultiConfig = None) -> None:
        super().__init__(init_cfg=init_cfg)
        self.strides = strides
        self.num_input_levels = len(in_channels)
        self.num_encoder_levels = \
            encoder.layer_cfg.self_attn_cfg.num_levels
        self.T = T
        assert self.num_encoder_levels >= 1, \
            'num_levels in attn_cfgs must be at least one'
        input_conv_list = []
        input_conv_spike_list = []
        # from top to bottom (low to high resolution)
        for i in range(self.num_input_levels - 1,
                       self.num_input_levels - self.num_encoder_levels - 1,
                       -1):
            # NOTE: see MS_Conv_block for reference. The original ConvModule
            # (1x1 conv + norm, no activation) is replaced here by a spiking
            # activation followed by Conv2d + BatchNorm2d.
            input_conv_spike = MultiSpike_norm4(T=self.T)
            input_conv = nn.Sequential(
                nn.Conv2d(
                    in_channels[i], feat_channels, kernel_size=1, stride=1),
                nn.BatchNorm2d(feat_channels))
            input_conv_spike_list.append(input_conv_spike)
            input_conv_list.append(input_conv)

        self.input_convs = ModuleList(input_conv_list)
        self.input_convs_spike = ModuleList(input_conv_spike_list)

        self.encoder = Mask2FormerTransformerEncoder(**encoder)
        self.positional_encoding = SinePositionalEncoding(
            **positional_encoding)
        # one learned embedding per encoder level, added to the positional
        # encoding (index 0 corresponds to the lowest-resolution level)
        self.level_encoding = nn.Embedding(self.num_encoder_levels,
                                           feat_channels)

        # fpn-like structure
        self.lateral_convs = ModuleList()
        self.lateral_convs_spikes = ModuleList()
        self.output_convs = ModuleList()
        self.output_convs_spikes = ModuleList()
        self.use_bias = norm_cfg is None
        # from top to bottom (low to high resolution)
        # FPN for the remaining features that were not fed to the encoder
        for i in range(self.num_input_levels - self.num_encoder_levels - 1,
                       -1, -1):
            lateral_conv_spike = MultiSpike_norm4(T=self.T)
            lateral_conv = nn.Sequential(
                nn.Conv2d(
                    in_channels[i], feat_channels, kernel_size=1, stride=1),
                nn.BatchNorm2d(feat_channels))
            output_conv_spike = MultiSpike_norm4(T=self.T)
            # NOTE: the original output ConvModule ended with a ReLU
            # (act_cfg); here the nonlinearity is the spiking activation
            # applied before the conv instead.
            output_conv = nn.Sequential(
                nn.Conv2d(
                    feat_channels,
                    feat_channels,
                    kernel_size=3,
                    stride=1,
                    padding=1),
                nn.BatchNorm2d(feat_channels))
            self.lateral_convs_spikes.append(lateral_conv_spike)
            self.lateral_convs.append(lateral_conv)
            self.output_convs_spikes.append(output_conv_spike)
            self.output_convs.append(output_conv)

        # NOTE: out_spike is defined but not applied in forward()
        self.out_spike = MultiSpike_norm4(T=self.T)
        self.mask_feature = Conv2d(
            feat_channels, out_channels, kernel_size=1, stride=1, padding=0)
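        # mask_feature projects the finest FPN output to the per-pixel
        # embedding that the Mask2Former decoder head correlates with query
        # embeddings to produce mask logits.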

        self.num_outs = num_outs
        self.point_generator = MlvlPointGenerator(strides)

    def init_weights(self) -> None:
        """Initialize weights."""
        for i in range(0, self.num_encoder_levels):
            xavier_init(
                self.input_convs[i][0],
                gain=1,
                bias=0,
                distribution='uniform')

        for i in range(0, self.num_input_levels - self.num_encoder_levels):
            caffe2_xavier_init(self.lateral_convs[i][0], bias=0)
            caffe2_xavier_init(self.output_convs[i][0], bias=0)

        caffe2_xavier_init(self.mask_feature, bias=0)

        normal_init(self.level_encoding, mean=0, std=1)
        for p in self.encoder.parameters():
            if p.dim() > 1:
                nn.init.xavier_normal_(p)

        # init_weights defined in MultiScaleDeformableAttention
        for m in self.encoder.layers.modules():
            if isinstance(m, MultiScaleDeformableAttention):
                m.init_weights()

    def forward(self, feats: List[Tensor]) -> Tuple[Tensor, List[Tensor]]:
        """
        Args:
            feats (list[Tensor]): Feature maps of each level. Each has
                shape of (batch_size, c, h, w).

        Returns:
            tuple: A tuple containing the following:

                - mask_feature (Tensor): shape (batch_size, c, h, w).
                - multi_scale_features (list[Tensor]): Multi scale \
                        features, each in shape (batch_size, c, h, w).
        """
        # generate the padding mask for each level of each image;
        # feats[i] has shape (batch_size, c, h_i, w_i)
        batch_size = feats[0].shape[0]
        encoder_input_list = []
        padding_mask_list = []
        level_positional_encoding_list = []
        spatial_shapes = []
        reference_points_list = []
        for i in range(self.num_encoder_levels):
            level_idx = self.num_input_levels - i - 1
            feat = feats[level_idx]
            feat = self.input_convs_spike[i](feat)
            feat_projected = self.input_convs[i](feat)
            h, w = feat.shape[-2:]

            # no padding
            padding_mask_resized = feat.new_zeros(
                (batch_size,) + feat.shape[-2:], dtype=torch.bool)
            pos_embed = self.positional_encoding(padding_mask_resized)
            level_embed = self.level_encoding.weight[i]
            level_pos_embed = level_embed.view(1, -1, 1, 1) + pos_embed
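            # level_pos_embed: the (1, c, 1, 1) level embedding broadcast
            # over pos_embed of shape (batch_size, c, h_i, w_i)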
            # (h_i * w_i, 2)
            reference_points = self.point_generator.single_level_grid_priors(
                feat.shape[-2:], level_idx, device=feat.device)
            # normalize
            factor = feat.new_tensor([[w, h]]) * self.strides[level_idx]
            reference_points = reference_points / factor
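            # e.g. for an 8x8 map at stride 32 (a 256x256 input), the grid
            # prior (16., 16.) (default offset 0.5) normalizes to
            # (16 / (8 * 32), 16 / (8 * 32)) = (0.0625, 0.0625)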

            # shape (batch_size, c, h_i, w_i) -> (batch_size, h_i * w_i, c)
            feat_projected = feat_projected.flatten(2).permute(0, 2, 1)
            level_pos_embed = level_pos_embed.flatten(2).permute(0, 2, 1)
            padding_mask_resized = padding_mask_resized.flatten(1)

            encoder_input_list.append(feat_projected)
            padding_mask_list.append(padding_mask_resized)
            level_positional_encoding_list.append(level_pos_embed)
            spatial_shapes.append(feat.shape[-2:])
            reference_points_list.append(reference_points)
        # shape (batch_size, total_num_queries),
        # total_num_queries = sum(h_i * w_i over levels)
        padding_masks = torch.cat(padding_mask_list, dim=1)
        # shape (batch_size, total_num_queries, c)
        encoder_inputs = torch.cat(encoder_input_list, dim=1)
        level_positional_encodings = torch.cat(
            level_positional_encoding_list, dim=1)
        device = encoder_inputs.device
        # shape (num_encoder_levels, 2), from low
        # resolution to high resolution
        spatial_shapes = torch.as_tensor(
            spatial_shapes, dtype=torch.long, device=device)
        # level_start_index: (0, h_0*w_0, h_0*w_0 + h_1*w_1, ...)
        level_start_index = torch.cat((spatial_shapes.new_zeros(
            (1,)), spatial_shapes.prod(1).cumsum(0)[:-1]))
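        # e.g. spatial_shapes [[8, 8], [16, 16], [32, 32]] gives
        # level_start_index [0, 64, 320]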
        reference_points = torch.cat(reference_points_list, dim=0)
        reference_points = reference_points[None, :, None].repeat(
            batch_size, 1, self.num_encoder_levels, 1)
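        # shape (batch_size, total_num_queries, num_encoder_levels, 2)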
        valid_ratios = reference_points.new_ones(
            (batch_size, self.num_encoder_levels, 2))
        # shape (batch_size, num_total_queries, c)
        memory = self.encoder(
            query=encoder_inputs,
            query_pos=level_positional_encodings,
            key_padding_mask=padding_masks,
            spatial_shapes=spatial_shapes,
            reference_points=reference_points,
            level_start_index=level_start_index,
            valid_ratios=valid_ratios)
        # (batch_size, c, num_total_queries)
        memory = memory.permute(0, 2, 1)

        # from low resolution to high resolution
        num_queries_per_level = [e[0] * e[1] for e in spatial_shapes]
        outs = torch.split(memory, num_queries_per_level, dim=-1)
        outs = [
            x.reshape(batch_size, -1, spatial_shapes[i][0],
                      spatial_shapes[i][1]) for i, x in enumerate(outs)
        ]
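        # outs[i]: (batch_size, c, h_i, w_i), ordered low -> high resolution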
        # FPN: each remaining (higher-resolution) level is fused with the
        # previous output only
        for i in range(self.num_input_levels - self.num_encoder_levels - 1,
                       -1, -1):
            x = feats[i]
            x = self.lateral_convs_spikes[i](x)
            cur_feat = self.lateral_convs[i](x)

            y = cur_feat + F.interpolate(
                outs[-1],
                size=cur_feat.shape[-2:],
                mode='bilinear',
                align_corners=False)

            y = self.output_convs_spikes[i](y)
            y = self.output_convs[i](y)
            outs.append(y)
        multi_scale_features = outs[:self.num_outs]

        mask_feature = self.mask_feature(outs[-1])
        return mask_feature, multi_scale_features
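

if __name__ == '__main__':
    # Minimal smoke-test sketch. The encoder config below is a hypothetical
    # Deformable-DETR style config modeled on the standard Mask2Former pixel
    # decoder settings; it is NOT taken from this repo's configs.
    from mmengine.config import ConfigDict

    encoder_cfg = ConfigDict(
        num_layers=6,
        layer_cfg=dict(
            self_attn_cfg=dict(
                embed_dims=256,
                num_heads=8,
                num_levels=3,
                num_points=4,
                dropout=0.0,
                batch_first=True),
            ffn_cfg=dict(
                embed_dims=256,
                feedforward_channels=1024,
                num_fcs=2,
                ffn_drop=0.0,
                act_cfg=dict(type='ReLU', inplace=True))))
    decoder = MSDeformAttnPixelDecoder(encoder=encoder_cfg)
    decoder.init_weights()
    # 4-level backbone features for a 256x256 input, strides [4, 8, 16, 32].
    feats = [
        torch.rand(2, c, 256 // s, 256 // s)
        for c, s in zip([256, 512, 1024, 2048], [4, 8, 16, 32])
    ]
    mask_feature, multi_scale_features = decoder(feats)
    # Expected: mask_feature (2, 256, 64, 64); multi_scale_features ordered
    # low -> high resolution, e.g. multi_scale_features[0] is (2, 256, 8, 8).
    print(mask_feature.shape, [f.shape for f in multi_scale_features])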
